{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "## 06子词嵌入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.604170Z",
     "iopub.status.busy": "2023-08-18T06:56:35.603510Z",
     "iopub.status.idle": "2023-08-18T06:56:35.611979Z",
     "shell.execute_reply": "2023-08-18T06:56:35.611231Z"
    },
    "origin_pos": 1,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "symbols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',\n",
    "           'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',\n",
    "           '_', '[UNK]']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.615843Z",
     "iopub.status.busy": "2023-08-18T06:56:35.615201Z",
     "iopub.status.idle": "2023-08-18T06:56:35.623942Z",
     "shell.execute_reply": "2023-08-18T06:56:35.623209Z"
    },
    "origin_pos": 4,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}\n",
    "token_freqs = {}\n",
    "for token, freq in raw_token_freqs.items():\n",
    "    token_freqs[' '.join(list(token))] = raw_token_freqs[token]\n",
    "token_freqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.627616Z",
     "iopub.status.busy": "2023-08-18T06:56:35.627025Z",
     "iopub.status.idle": "2023-08-18T06:56:35.631950Z",
     "shell.execute_reply": "2023-08-18T06:56:35.631221Z"
    },
    "origin_pos": 6,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def get_max_freq_pair(token_freqs):\n",
    "    pairs = collections.defaultdict(int)\n",
    "    for token, freq in token_freqs.items():\n",
    "        symbols = token.split()\n",
    "        for i in range(len(symbols) - 1):\n",
    "            # “pairs”的键是两个连续符号的元组\n",
    "            pairs[symbols[i], symbols[i + 1]] += freq\n",
    "    return max(pairs, key=pairs.get)  # 具有最大值的“pairs”键"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.635554Z",
     "iopub.status.busy": "2023-08-18T06:56:35.634913Z",
     "iopub.status.idle": "2023-08-18T06:56:35.639631Z",
     "shell.execute_reply": "2023-08-18T06:56:35.638892Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def merge_symbols(max_freq_pair, token_freqs, symbols):\n",
    "    symbols.append(''.join(max_freq_pair))\n",
    "    new_token_freqs = dict()\n",
    "    for token, freq in token_freqs.items():\n",
    "        new_token = token.replace(' '.join(max_freq_pair),\n",
    "                                  ''.join(max_freq_pair))\n",
    "        new_token_freqs[new_token] = token_freqs[token]\n",
    "    return new_token_freqs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.643247Z",
     "iopub.status.busy": "2023-08-18T06:56:35.642643Z",
     "iopub.status.idle": "2023-08-18T06:56:35.647847Z",
     "shell.execute_reply": "2023-08-18T06:56:35.647061Z"
    },
    "origin_pos": 10,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "合并# 1: ('t', 'a')\n",
      "合并# 2: ('ta', 'l')\n",
      "合并# 3: ('tal', 'l')\n",
      "合并# 4: ('f', 'a')\n",
      "合并# 5: ('fa', 's')\n",
      "合并# 6: ('fas', 't')\n",
      "合并# 7: ('e', 'r')\n",
      "合并# 8: ('er', '_')\n",
      "合并# 9: ('tall', '_')\n",
      "合并# 10: ('fast', '_')\n"
     ]
    }
   ],
   "source": [
    "num_merges = 10\n",
    "for i in range(num_merges):\n",
    "    max_freq_pair = get_max_freq_pair(token_freqs)\n",
    "    token_freqs = merge_symbols(max_freq_pair, token_freqs, symbols)\n",
    "    print(f'合并# {i+1}:',max_freq_pair)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.651408Z",
     "iopub.status.busy": "2023-08-18T06:56:35.650818Z",
     "iopub.status.idle": "2023-08-18T06:56:35.654893Z",
     "shell.execute_reply": "2023-08-18T06:56:35.654143Z"
    },
    "origin_pos": 12,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'ta', 'tal', 'tall', 'fa', 'fas', 'fast', 'er', 'er_', 'tall_', 'fast_']\n"
     ]
    }
   ],
   "source": [
    "print(symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.658487Z",
     "iopub.status.busy": "2023-08-18T06:56:35.657897Z",
     "iopub.status.idle": "2023-08-18T06:56:35.662020Z",
     "shell.execute_reply": "2023-08-18T06:56:35.661268Z"
    },
    "origin_pos": 14,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['fast_', 'fast er_', 'tall_', 'tall er_']\n"
     ]
    }
   ],
   "source": [
    "print(list(token_freqs.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.665538Z",
     "iopub.status.busy": "2023-08-18T06:56:35.664918Z",
     "iopub.status.idle": "2023-08-18T06:56:35.670601Z",
     "shell.execute_reply": "2023-08-18T06:56:35.669830Z"
    },
    "origin_pos": 16,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "def segment_BPE(tokens, symbols):\n",
    "    outputs = []\n",
    "    for token in tokens:\n",
    "        start, end = 0, len(token)\n",
    "        cur_output = []\n",
    "        # 具有符号中可能最长子字的词元段\n",
    "        while start < len(token) and start < end:\n",
    "            if token[start: end] in symbols:\n",
    "                cur_output.append(token[start: end])\n",
    "                start = end\n",
    "                end = len(token)\n",
    "            else:\n",
    "                end -= 1\n",
    "        if start < len(token):\n",
    "            cur_output.append('[UNK]')\n",
    "        outputs.append(' '.join(cur_output))\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T06:56:35.674172Z",
     "iopub.status.busy": "2023-08-18T06:56:35.673554Z",
     "iopub.status.idle": "2023-08-18T06:56:35.677812Z",
     "shell.execute_reply": "2023-08-18T06:56:35.677058Z"
    },
    "origin_pos": 18,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['tall e s t _', 'fa t t er_']\n"
     ]
    }
   ],
   "source": [
    "tokens = ['tallest_', 'fatter_']\n",
    "print(segment_BPE(tokens, symbols))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
